import logging
import os


import gym
import numpy as np
import torch

from agents.interfaces import LearnerAlone, Learner
from tools.envs import  make_dispatch
from tools.utils import size_action_space, preprocess, hot_encoding_index


def image_env(envs,index,goals,save_dir,spe=None,args=None,obs=None):
    """Save one frame as a PNG with the goal label drawn in its corner.

    The frame is ``obs`` when it is given; otherwise it is rendered from
    ``envs`` (render mode "wrapped" when ``args`` describes a multiworld run
    with images, "rgb_array" in every other case).

    The label is the concatenation of ``goals[0, 0]`` .. ``goals[0, index]``.
    It is drawn on the picture and also embedded in the file name, which is
    written below ``save_dir`` and optionally prefixed with ``spe``.
    """
    if obs is not None:
        frame = obs
    else:
        use_wrapped = args and args.env_type == "multiworld" and args.image
        frame = envs.render(mode="wrapped" if use_wrapped else "rgb_array")

    # One string fragment per goal entry, joined without separators.
    label = "".join(str(goals[0, k].item()) for k in range(index + 1))

    from PIL import Image, ImageDraw, ImageFont

    prefix = "" if spe is None else spe
    path = "".join([save_dir, "/", prefix, "index-", str(index), "-goal-", label, '-statistics.png'])

    picture = Image.fromarray(frame)
    canvas = ImageDraw.Draw(picture)
    # The font file is resolved relative to the current working directory.
    label_font = ImageFont.truetype("resources/arial.ttf", 100)  # ImageFont.load_default()
    canvas.text((40, 40), label, fill=(255, 255, 255), font=label_font)
    picture.save(path)

def generate_gwr_goals(coordpolicy,args):
    """Collect every goal known to ``coordpolicy`` together with index tensors.

    ``args`` is accepted for interface compatibility but is not used here.

    Returns a 4-tuple:
      * the goals from ``coordpolicy.generate_all()``, flattened to one row
        per goal and detached from the autograd graph,
      * the goal indices returned by ``generate_all()`` as a column,
      * ``hot_encoding_index`` applied to the positional indices with a size
        of ``(number_of_goals + 1,)`` (presumably a one-hot code - confirm
        against ``tools.utils``),
      * the positional indices ``0 .. number_of_goals - 1`` as a column.
    """
    goals, goal_ids = coordpolicy.generate_all()

    flat_goals = goals.reshape((goals.shape[0], -1))
    n_goals = flat_goals.shape[0]

    goal_ids = goal_ids.view((-1, 1))
    positions = torch.arange(n_goals).view((-1, 1))
    index_encodes = hot_encoding_index((n_goals + 1,), positions, dim=1)

    return flat_goals.detach(), goal_ids, index_encodes, positions


def multiworld_images(sublearner,save_dir,args,coordpolicy,goal_space,video=0,**kwargs):
    """Roll out the policy towards every generated goal and save the outcome.

    All goals produced by ``generate_gwr_goals`` are first dumped to
    ``<save_dir>/test_goals.txt``. Then, for each goal, the coordination
    policy and the sub-learner's algorithm are run for at most
    ``args.max_steps`` steps (stopping early when the environment reports
    ``done``).

    * ``video`` falsy: one environment is shared by all goals and the final
      frame of each rollout is saved under ``<save_dir>/end`` via
      ``image_env``.
    * ``video`` truthy: a fresh environment named ``"video<goal index>"`` is
      created and closed for every goal; no still image is written.

    ``goal_space`` is not used by this function. Extra ``kwargs`` are
    forwarded to ``make_dispatch``.
    """
    # There is a weird bug with videos, so the environment handling is
    # duplicated: built once without videos, rebuilt per goal with videos.

    policy_save_dir = save_dir + "/end"
    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # works when the parent of save_dir does not exist yet.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(policy_save_dir, exist_ok=True)

    all_goals, all_index_goals, _, _ = generate_gwr_goals(coordpolicy, args)

    # The context manager guarantees the file is closed even if savetxt fails.
    with open(save_dir + "/test_goals.txt", "w") as goal_file:
        for row in all_goals:
            np.savetxt(goal_file, row)

    #####Evaluation
    if not video:
        envs = make_dispatch(args.env_type, args, save_dir=save_dir, reset=video, name="video", video=video, **kwargs)
        obs = envs.reset()

    for i in range(all_goals.shape[0]):
        if video:
            envs = make_dispatch(args.env_type, args, save_dir=save_dir, reset=video, name="video" + str(all_index_goals[i].item()), video=video, **kwargs)
            obs = envs.reset()

        predefined_goal = all_goals[i:i + 1].view(1, -1)
        # Loop-invariant: the raw goal row is passed as the predefined state
        # on every step, independently of the (possibly re-embedded) goal.
        predefined_state = all_goals[i:i + 1].view(1, -1)
        if args.state:
            # Replace the raw goal by the embedding of an observation sampled
            # from the buffer of the unit nearest to that goal.
            cluster = sublearner.mpolicy.rollouts.oegn.find_nearest_units(predefined_goal)[0]
            buffer = sublearner.mpolicy.rollouts.oegn.buffers[cluster]
            goal_obs = buffer.learnDataStore.sample_obs()
            predefined_goal = sublearner.mpolicy.context.estimator.goal_embed(preprocess(goal_obs, args), act=True)

        for s in range(args.max_steps):
            state = torch.tensor(obs["state"]) if args.state else None
            goal, _ = coordpolicy.act(obs["observation"], predefined_goal=predefined_goal, predefined_state=predefined_state, step=s, state=state)
            act = sublearner.mpolicy.algo.act(obs["observation"], goals=goal, use_goals=True)
            obs, _, done, infos = envs.step(act)
            if done:
                break

        if not video:
            if args.image:
                # Image observations are saved directly instead of rendering.
                image_env(envs, 0, all_index_goals[i:i + 1], policy_save_dir, spe=None, args=args, obs=obs["observation"][0].numpy().transpose())
            else:
                image_env(envs, 0, all_index_goals[i:i + 1], policy_save_dir, spe=None, args=args)

            obs = envs.reset()
        if video:
            envs.close()
    if not video:
        envs.close()




def _robotic_goal_distance(env_name, envs, infos, goals_states, index):
    """Return the current distance to goal number ``index`` for a sawyer task."""
    if env_name == "sawyer_door":
        return np.abs(envs.get_door_angle() - goals_states[index, -1])[0]
    if env_name == "sawyer_pickup":
        return np.linalg.norm(envs.get_obj_pos() - goals_states[index, 3:])
    # sawyer_push: the distance is read from the step info dictionary.
    return infos["puck_distance"]


def evaluate_robotic(env_name,args,policy,context,i=0):
    """Evaluate ``policy`` on 50 goals of a sawyer task and log the distances.

    Does nothing when ``args.env_name`` is ``"point"``.

    For ``sawyer_door`` and ``sawyer_pickup`` the goals are drawn at random
    from ``context.object`` (indices below 1000, or below 500 for pickup);
    for ``sawyer_push`` the goal comes from each reset observation. Every
    episode runs ``args.max_steps`` steps. The mean final distance and the
    mean per-episode minimum distance are written to the ``"distances"``
    logger as ``"<final>;<min>"``.

    ``i`` is unused and only kept so existing callers keep working.

    Raises:
        ValueError: if ``env_name`` is not one of the supported sawyer tasks.
    """
    if args.env_name == "point":
        return

    supported = ("sawyer_door", "sawyer_push", "sawyer_pickup")
    if env_name not in supported:
        # Fail before an environment is created; previously this surfaced as
        # an UnboundLocalError in the middle of the first episode.
        raise ValueError("evaluate_robotic does not support env_name=%r; expected one of %r" % (env_name, supported))

    num_evals = 50
    goals_states = None
    goals_obs = None
    if env_name == "sawyer_door" or env_name == "sawyer_pickup":
        pool_size = 500 if env_name == "sawyer_pickup" else 1000
        idx = np.random.randint(0, pool_size, num_evals)
        goal_data = {k: v[idx] for k, v in context.object.items()}
        goals_states = goal_data['state_desired_goal']
        if args.image:
            goals_obs = torch.from_numpy(goal_data['image_desired_goal']).view(-1, 3, 48, 48).float().to(args.device)
        else:
            goals_obs = torch.from_numpy(goal_data['state_desired_goal']).float().to(args.device)

    envs = make_dispatch(args.env_type, args, None, video=False, reset=False)

    mean_distances = 0
    mean_min_distances = 0
    infos = None
    try:
        for episode in range(num_evals):
            obs = envs.reset()
            if goals_obs is not None:
                goal_obs = goals_obs[episode:episode + 1]
            elif not args.image:
                # sawyer_push with state goals: taken from the reset observation.
                goal_obs = torch.from_numpy(obs["desired_goal"]).float().unsqueeze(0).to(args.device)
            else:
                # sawyer_push with image goals: only this branch is preprocessed.
                goal_obs = torch.from_numpy(obs["image_desired_goal"]).view(-1, 3, 48, 48).float().to(args.device)
                goal_obs = preprocess(goal_obs, args)

            goal_embed_obs = context.estimator.goal_embed(goal_obs, act=True)
            min_dist = 10000
            for _ in range(args.max_steps):
                act = policy.act(obs["observation"], goal_embed_obs, deterministic=args.deterministic_eval)
                obs, _, _, infos = envs.step(act)
                act_dist = _robotic_goal_distance(env_name, envs, infos, goals_states, episode)
                min_dist = min(min_dist, act_dist)
            mean_min_distances += min_dist

            # Distance in the state reached at the end of the episode.
            mean_distances += _robotic_goal_distance(env_name, envs, infos, goals_states, episode)
    finally:
        # Release the environment even when a rollout fails.
        envs.close()

    logger_distances = logging.getLogger("distances")
    logger_distances.info(str(mean_distances/num_evals)+";"+str(mean_min_distances/num_evals))

def evaluate_maze(args,policy,context,save_dir):
    """Run 50 goal-reaching episodes in the maze and log the mean distance.

    Goals come from ``context.loaded_pos`` when ``args.env_name`` is
    ``"PointMaze1Pos-v1"`` and from ``context.loaded_obs`` otherwise; each is
    preprocessed and embedded with ``context.estimator.goal_embed``. After
    every episode the current frame is saved through ``image_env``. The
    accumulated distance divided by the number of episodes is sent to the
    ``"distances"`` logger.
    """
    num_evals = 50
    goals_obs = context.loaded_obs
    goals_states = context.loaded_pos
    envs = make_dispatch(args.env_type, args, save_dir=save_dir, video=False, reset=False)

    total_distance = 0
    for episode in range(num_evals):
        obs = envs.reset()
        goal_slice = slice(episode, episode + 1)

        if args.env_name == "PointMaze1Pos-v1":
            raw_goal = goals_states[goal_slice]
        else:
            raw_goal = goals_obs[goal_slice]
        goal_embed_obs = context.estimator.goal_embed(preprocess(raw_goal, args), act=True)

        predefined_state = None
        if args.state:
            predefined_state = goals_states[goal_slice]

        for step in range(args.max_steps):
            if args.state:
                state = torch.tensor(obs["state"])
            else:
                state = None
            goal, _ = context.coord_actor.act(
                obs["observation"],
                predefined_goal=goal_embed_obs,
                predefined_state=predefined_state,
                state=state,
                step=step,
            )
            act = policy.act(obs["observation"], goal, use_goals=True, deterministic=args.deterministic_eval)
            obs, _, _, infos = envs.step(act)
            # NOTE(review): the distance is accumulated on every step, not
            # only at the end of the episode - confirm this is intended.
            target = (goals_states[goal_slice]).numpy()
            total_distance += np.linalg.norm(obs["achieved_goal"] - target)

        image_env(envs, 0, torch.tensor([[episode]]), save_dir, spe=None, args=args)

    envs.close()
    logging.getLogger("distances").info(total_distance / num_evals)


def eval_maze(sublearner,coordlearner,save_dir,number_steps=100,args=None,**kwargs):
    """Evaluate the hierarchical policy for roughly ``number_steps`` steps.

    Episodes are run back to back with the coordination policy in
    ``eval_mode``. The step budget is only checked between episodes, so the
    last episode always runs until the environment reports ``done``.

    An episode counts as successful when ``infos["is_success"]`` was truthy on
    at least one step. When ``args.env_type`` is ``"maze"`` the success rate
    (successes / episodes), the summed reward divided by ``number_steps`` and
    the sub-learner's ``total_num_steps`` are written to its ``eval_logger``.
    When ``args.plan_interval`` is not -1, the path stored at the first step
    and the executed path are logged to the ``"plans"`` logger per episode.

    Extra ``kwargs`` are forwarded to ``make_dispatch``.
    """
    #####Behaviour evaluation
    envs = make_dispatch(args.env_type, args, save_dir=save_dir, **kwargs)
    coord_policy = coordlearner.mpolicy.policy

    steps_done = 0
    all_success = 0
    reward = 0
    resets = 0
    coord_policy.eval_mode = True
    try:
        while steps_done < number_steps:
            done = False
            success = 0
            obs = envs.reset()
            state = torch.tensor(obs["state"]) if args.state else None
            obs = preprocess(obs["observation"], args)
            resets += 1

            s = 0
            while not done:
                goal, _ = coord_policy.act(obs, step=s, state=state)
                if s == 0 and args.plan_interval != -1:
                    # Remember the plan made at the start of the episode.
                    store_path = coord_policy.path
                action = sublearner.mpolicy.algo.act(obs, goal, deterministic=args.deterministic_eval, state=state)
                obs, r, done, infos = envs.step(action)
                state = torch.tensor(obs["state"]) if args.state else None
                obs = preprocess(obs["observation"], args)
                if infos["is_success"]:
                    success = 1
                reward += r
                steps_done += 1
                s += 1

            if args.plan_interval != -1:
                logging.getLogger("plans").info(str(store_path)+" "+str(coord_policy.executed_path))
            all_success += success
    finally:
        # Always leave evaluation mode and release the environment created
        # above, even if a rollout step raised.
        coord_policy.eval_mode = False
        envs.close()

    all_success /= resets
    reward /= number_steps
    timestep = sublearner.mpolicy.total_num_steps
    if args.env_type == "maze":
        eval_logger = sublearner.mpolicy.context.eval_logger
        eval_logger.log_tabular("success", all_success)
        eval_logger.log_tabular("reward", reward.item())
        eval_logger.log_tabular("timestep", timestep)
        eval_logger.dump_tabular()


